{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CatBoost and CoreML tutorial — Titanic dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CatBoost supports exporting models to Apple's [CoreML](https://developer.apple.com/machine-learning/) format, which lets you embed ML models in applications on Apple's platforms.\n",
    "\n",
    "Currently, export is supported only for models that use float and one-hot encoded features.\n",
    "\n",
    "This tutorial demonstrates how to export a CatBoost model trained on the [Titanic](https://www.kaggle.com/c/titanic/data) dataset to a CoreML model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the Titanic dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from catboost import Pool, CatBoost\n",
    "from catboost.datasets import titanic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = titanic()[0]\n",
    "X, y = train_df.drop('Survived', axis=1), train_df.Survived"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Braund, Mr. Owen Harris</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5 21171</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>PC 17599</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>Heikkinen, Miss. Laina</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101282</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113803</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>Allen, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>373450</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Pclass                                               Name  \\\n",
       "0            1       3                            Braund, Mr. Owen Harris   \n",
       "1            2       1  Cumings, Mrs. John Bradley (Florence Briggs Th...   \n",
       "2            3       3                             Heikkinen, Miss. Laina   \n",
       "3            4       1       Futrelle, Mrs. Jacques Heath (Lily May Peel)   \n",
       "4            5       3                           Allen, Mr. William Henry   \n",
       "\n",
       "      Sex   Age  SibSp  Parch            Ticket     Fare Cabin Embarked  \n",
       "0    male  22.0      1      0         A/5 21171   7.2500   NaN        S  \n",
       "1  female  38.0      1      0          PC 17599  71.2833   C85        C  \n",
       "2  female  26.0      0      0  STON/O2. 3101282   7.9250   NaN        S  \n",
       "3  female  35.0      1      0            113803  53.1000  C123        S  \n",
       "4    male  35.0      0      0            373450   8.0500   NaN        S  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Drop the `Name` and `Ticket` features. One-hot encoding them makes little sense: nearly every value belongs to a single object, so the model would only overfit to them."
   ]
  },
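  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the reasoning above concrete, here is a minimal sketch in plain Python. The rows are made-up toy values, not taken from the dataset, and `unique_ratio` is a hypothetical helper introduced only for this illustration. A column whose ratio is close to 1 would get roughly one one-hot indicator per object, which gives the model nothing to generalize from.\n",
    "\n",
    "```python\n",
    "# Toy rows invented for illustration (assumption: not real Titanic records).\n",
    "rows = [\n",
    "    {'Name': 'Passenger A', 'Sex': 'male'},\n",
    "    {'Name': 'Passenger B', 'Sex': 'female'},\n",
    "    {'Name': 'Passenger C', 'Sex': 'female'},\n",
    "    {'Name': 'Passenger D', 'Sex': 'male'},\n",
    "]\n",
    "\n",
    "\n",
    "def unique_ratio(rows, column):\n",
    "    # Fraction of distinct values among all values of the column.\n",
    "    values = [row[column] for row in rows]\n",
    "    return len(set(values)) / len(values)\n",
    "\n",
    "\n",
    "for column in ('Name', 'Sex'):\n",
    "    print(column, unique_ratio(rows, column))\n",
    "```"
   ]
  },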
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X.drop(['Name', 'Ticket'], axis=1, inplace=True)\n",
    "# np.float was removed in NumPy 1.24; the builtin float is equivalent here.\n",
    "categorical_features_indices = np.where(X.dtypes != float)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Treat every non-float column as categorical.\n",
    "is_cat = (X.dtypes != float)\n",
    "for feature, feat_is_cat in is_cat.to_dict().items():\n",
    "    if feat_is_cat:\n",
    "        # Assign the result back instead of calling fillna(inplace=True) on a column selection.\n",
    "        X[feature] = X[feature].fillna(\"NAN\")\n",
    "\n",
    "cat_features_index = np.where(is_cat)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>NAN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>NAN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>NAN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Pclass     Sex   Age  SibSp  Parch     Fare Cabin Embarked\n",
       "0            1       3    male  22.0      1      0   7.2500   NAN        S\n",
       "1            2       1  female  38.0      1      0  71.2833   C85        C\n",
       "2            3       3  female  26.0      0      0   7.9250   NAN        S\n",
       "3            4       1  female  35.0      1      0  53.1000  C123        S\n",
       "4            5       3    male  35.0      0      0   8.0500   NAN        S"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pool = Pool(data=X, label=y, cat_features=cat_features_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Learning rate set to 0.016216\n",
      "0:\tlearn: 0.6862663\ttotal: 73.2ms\tremaining: 1m 13s\n",
      "100:\tlearn: 0.4272007\ttotal: 2.6s\tremaining: 23.2s\n",
      "200:\tlearn: 0.4044455\ttotal: 4.59s\tremaining: 18.3s\n",
      "300:\tlearn: 0.3928060\ttotal: 6.28s\tremaining: 14.6s\n",
      "400:\tlearn: 0.3852512\ttotal: 7.88s\tremaining: 11.8s\n",
      "500:\tlearn: 0.3750366\ttotal: 9.58s\tremaining: 9.54s\n",
      "600:\tlearn: 0.3624703\ttotal: 12.5s\tremaining: 8.28s\n",
      "700:\tlearn: 0.3493490\ttotal: 15s\tremaining: 6.39s\n",
      "800:\tlearn: 0.3390201\ttotal: 17.3s\tremaining: 4.3s\n",
      "900:\tlearn: 0.3301084\ttotal: 19.5s\tremaining: 2.14s\n",
      "999:\tlearn: 0.3224542\ttotal: 21.6s\tremaining: 0us\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<catboost.core.CatBoost at 0x110317050>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = CatBoost(params={'loss_function': 'Logloss', 'one_hot_max_size': 255, 'verbose': 100})\n",
    "model.fit(train_pool)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict probabilities:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_pool = Pool(data=X[0:1], cat_features=cat_features_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.88003974, 0.11996026]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(test_pool, prediction_type=\"Probability\")"
   ]
  },
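  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a binary `Logloss` model, the two numbers in the row are the probabilities of class 0 and class 1, and they sum to 1. The sketch below assumes that the class 1 probability is a sigmoid of the model's raw score, and it uses only the probability printed above. `sigmoid` and `logit` are helper names introduced here for illustration; they are not part of the CatBoost API.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(raw):\n",
    "    # Map a raw score to a probability in (0, 1).\n",
    "    return 1.0 / (1.0 + math.exp(-raw))\n",
    "\n",
    "\n",
    "def logit(p):\n",
    "    # Inverse of the sigmoid: recover the raw score from a probability.\n",
    "    return math.log(p / (1.0 - p))\n",
    "\n",
    "\n",
    "p_class1 = 0.11996026          # second value in the output above\n",
    "raw = logit(p_class1)          # raw score implied by that probability\n",
    "p_class0 = 1.0 - sigmoid(raw)  # complement, the first value in the output\n",
    "print(raw, p_class0)\n",
    "```"
   ]
  },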
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save the model in CoreML format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_model(\n",
    "    \"titanic.mlmodel\",\n",
    "    format=\"coreml\",\n",
    "    export_parameters={\n",
    "        'prediction_type': 'probability'\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the exported model, every feature is named `feature_i`, where `i` is the zero-based index of the feature in the dataset."
   ]
  },
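  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This naming can be written down explicitly. The sketch below lists the columns in the order shown by `X.head()` after `Name` and `Ticket` were dropped; `coreml_names` is a name introduced here for illustration only.\n",
    "\n",
    "```python\n",
    "# Column order of X after dropping Name and Ticket.\n",
    "columns = ['PassengerId', 'Pclass', 'Sex', 'Age', 'SibSp',\n",
    "           'Parch', 'Fare', 'Cabin', 'Embarked']\n",
    "\n",
    "# The zero-based position of each column gives its exported input name.\n",
    "coreml_names = {name: 'feature_{}'.format(i) for i, name in enumerate(columns)}\n",
    "\n",
    "for name, exported in coreml_names.items():\n",
    "    print(name, '->', exported)\n",
    "```"
   ]
  },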
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now you can import the saved model into Xcode and use it directly from Swift:\n",
    "\n",
    "```swift\n",
    "import CoreML\n",
    "\n",
    "let model = titanic()\n",
    "\n",
    "let passengerId = \"1\"\n",
    "let pclass = \"1\"\n",
    "let sex = \"female\"\n",
    "let age = 38.0\n",
    "let sibsp = \"1\"\n",
    "let parch = \"0\"\n",
    "let fare = 71.2833\n",
    "let cabin = \"C85\"\n",
    "let embarked = \"C\"\n",
    "\n",
    "guard let titanicOutput = try? model.prediction(feature_0: passengerId, feature_1: pclass, feature_2: sex, feature_3: age, feature_4: sibsp, feature_5: parch, feature_6: fare, feature_7: cabin, feature_8: embarked) else {\n",
    "    fatalError(\"Unexpected runtime error.\")\n",
    "}\n",
    "\n",
    "print(String(\n",
    "    format: \"Probability of survival: %1.5f\",\n",
    "    titanicOutput.prediction[0].doubleValue\n",
    "))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to practice, the Titanic model is easy to integrate into Apple's [MarsHabitatPricer](https://developer.apple.com/documentation/coreml/integrating_a_core_ml_model_into_your_app) example project:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img \n",
    "src=\"https://imgur.com/f6G6ZrJ.jpg\">"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
